package com.individual.spring.cloud.gateway.demo.filters;

import io.github.resilience4j.ratelimiter.RateLimiter;
import io.github.resilience4j.ratelimiter.RateLimiterConfig;
import io.github.resilience4j.ratelimiter.RateLimiterRegistry;
import io.vavr.control.Try;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.cloud.gateway.filter.GatewayFilter;
import org.springframework.cloud.gateway.filter.GatewayFilterChain;
import org.springframework.cloud.gateway.filter.GlobalFilter;
import org.springframework.cloud.gateway.filter.NettyWriteResponseFilter;
import org.springframework.cloud.gateway.filter.factory.rewrite.ModifyResponseBodyGatewayFilterFactory;
import org.springframework.core.Ordered;
import org.springframework.core.io.buffer.DataBuffer;
import org.springframework.http.HttpStatus;
import org.springframework.http.MediaType;
import org.springframework.http.server.reactive.ServerHttpResponse;
import org.springframework.stereotype.Component;
import org.springframework.web.server.ServerWebExchange;
import reactor.core.publisher.Mono;

import java.nio.charset.StandardCharsets;
import java.time.Duration;
import java.util.concurrent.Callable;

@Slf4j
@Component
@RequiredArgsConstructor
public class PreFilter implements GlobalFilter, Ordered {

    /**
     * Settings applied when a per-path limiter is first created in the registry:
     * one permit per 60-second window.
     *
     * <p>The timeout is zero on purpose. A non-zero timeout makes the limiter park
     * the calling thread while it waits for a permit. Here that thread is the
     * reactive (Netty) request thread, so waiting would stall unrelated requests.
     * Over-limit requests are therefore rejected immediately.
     */
    private static final RateLimiterConfig PATH_RATE_LIMIT_CONFIG = RateLimiterConfig.custom()
            .limitForPeriod(1)
            .limitRefreshPeriod(Duration.ofMillis(60000))
            .timeoutDuration(Duration.ZERO)
            .build();

    private final ModifyResponseBodyGatewayFilterFactory modifyResponseBodyGatewayFilterFactory;
    private final BodyRewrite bodyRewrite;
    private final RateLimiterRegistry rateLimiterRegistry;

    /**
     * Rate-limits the request by its path, then lets the response body be
     * rewritten by {@code bodyRewrite}.
     *
     * <p>If the path's limiter has no permit, the request is short-circuited with
     * {@code 429 Too Many Requests} and a plain-text body, and the chain is not invoked.
     * Otherwise the exchange is handed to a modify-response-body filter that applies
     * {@code bodyRewrite} to the {@code byte[]} response body.
     *
     * @param exchange the current server exchange
     * @param chain    the remaining gateway filter chain
     * @return completion signal for this filter's processing
     */
    @Override
    public Mono<Void> filter(ServerWebExchange exchange, GatewayFilterChain chain) {
        log.info("pre filter start");
        if (!rateLimit(exchange.getRequest().getPath().value())) {
            return writeTooManyRequests(exchange.getResponse());
        }
        GatewayFilter delegate = modifyResponseBodyGatewayFilterFactory.apply(
                new ModifyResponseBodyGatewayFilterFactory.Config()
                        .setRewriteFunction(byte[].class, byte[].class, bodyRewrite));
        return delegate.filter(exchange, chain)
                .then(Mono.fromRunnable(() -> log.info("pre filter end")));
    }

    /**
     * Places this filter just before {@link NettyWriteResponseFilter}, so the body
     * rewrite decorates the response before it is written.
     *
     * @return {@code NettyWriteResponseFilter.WRITE_RESPONSE_FILTER_ORDER - 1}
     */
    @Override
    public int getOrder() {
        return NettyWriteResponseFilter.WRITE_RESPONSE_FILTER_ORDER - 1;
    }

    /** Completes the response with a 429 status and a short plain-text explanation. */
    private static Mono<Void> writeTooManyRequests(ServerHttpResponse response) {
        response.setStatusCode(HttpStatus.TOO_MANY_REQUESTS);
        // The body is not JSON, so label it as plain text.
        response.getHeaders().setContentType(MediaType.TEXT_PLAIN);
        DataBuffer body = response.bufferFactory()
                .wrap("too many requests".getBytes(StandardCharsets.UTF_8));
        return response.writeWith(Mono.just(body));
    }

    /**
     * Tries to take one permit from the limiter registered under {@code path}.
     *
     * <p>The limiter is created with {@link #PATH_RATE_LIMIT_CONFIG} on first use.
     *
     * <p>NOTE(review): the raw request path is the registry key. Every distinct path,
     * including arbitrary client-chosen ones, creates a limiter that stays in the
     * registry. Consider keying by route ID instead.
     *
     * @param path request path used as the limiter name
     * @return {@code true} if a permit was acquired, {@code false} if the request
     *         must be rejected
     */
    private boolean rateLimit(String path) {
        log.debug("rate limit check for path {}", path);
        RateLimiter limiter = rateLimiterRegistry.rateLimiter(path, PATH_RATE_LIMIT_CONFIG);
        // The decorated callable throws when no permit is available; map any failure to "rejected".
        Callable<Boolean> permitted = RateLimiter.decorateCallable(limiter, () -> true);
        return Try.of(permitted::call).getOrElse(false);
    }

}
